from typing import List
from playhouse.pool import PooledSqliteDatabase
from peewee import (
    AutoField,
    CharField,
    BooleanField,
    IntegerField,
    Model,
    SqliteDatabase as PeeweeSqliteDatabase,
    chunked, ModelSelect, ModelDelete, FloatField
)

from mquant.trader.config import DB_PATH
from mquant.trader.object import AccountModelData, RiskManagerModelData

# Shared pooled SQLite database used by every model and by SqliteDatabase.
# check_same_thread=False is forwarded to sqlite3.connect so a connection may
# be used from a thread other than the one that created it.
# NOTE(review): in recent peewee versions PooledDatabase consumes `timeout` as
# the wait time for a free pooled connection rather than the SQLite busy
# timeout — confirm against the installed peewee version.
db: PeeweeSqliteDatabase = PooledSqliteDatabase(
    DB_PATH,
    timeout=10,
    max_connections=10,
    check_same_thread=False,
)


class AccountModelDataTable(Model):
    """Account data model, mapped to the ``account`` table.

    Field declaration order determines the column order of the created table.
    """
    id: AutoField = AutoField()  # auto-increment integer primary key
    username: str = CharField()
    account: str = CharField()
    # NOTE(review): stored exactly as supplied by the caller; any encryption
    # would have to happen before saving — not visible here.
    password: str = CharField()
    broker: str = CharField(null=True)  # optional
    trade_server: str = CharField()
    quotation_server: str = CharField()
    product_name: str = CharField(null=True)  # optional
    authorization_code: str = CharField(null=True)  # optional
    init_balance: float = FloatField(null=True)  # optional

    class Meta:
        table_name = "account"
        database: PeeweeSqliteDatabase = db
        # One unique composite index: each (username, account) pair may occur
        # only once; INSERT OR REPLACE upserts rely on this constraint.
        indexes: tuple = ((("username", "account"), True),)


class RiskManagerModelTable(Model):
    """Peewee model for the ``riskManager`` table.

    Holds risk-manager settings, unique per ``(username, account)`` pair.
    """
    id: AutoField = AutoField()
    username: str = CharField()
    account: str = CharField()
    max_loss: float = FloatField()  # units/sign convention not visible here
    sleep: int = IntegerField()  # meaning/units not visible here — TODO confirm

    class Meta:
        table_name = "riskManager"
        database: PeeweeSqliteDatabase = db
        # Unique composite index on (username, account).
        indexes: tuple = ((("username", "account"), True),)

    def _owning_account_row(self) -> "AccountModelDataTable":
        """Fetch the account row whose ``account`` equals this row's ``account``.

        Raises ``AccountModelDataTable.DoesNotExist`` when no such row exists.

        NOTE(review): the lookup filters on ``account`` only, while accounts are
        unique per ``(username, account)``; if two users share an account number
        an arbitrary matching row is returned — confirm this is intended.
        """
        same_account = AccountModelDataTable.account == self.account
        return AccountModelDataTable.get(same_account)

    def get_account(self):
        """Return the ``account`` value of the matching account row."""
        owner = self._owning_account_row()
        return owner.account

    def get_username(self):
        """Return the ``username`` of the matching account row."""
        owner = self._owning_account_row()
        return owner.username


class SqliteDatabase:
    """SQLite database interface for account and risk-manager records.

    Wraps the module-level pooled peewee database ``db`` and provides
    save / load / delete helpers for ``AccountModelDataTable`` and
    ``RiskManagerModelTable``.

    NOTE(review): only ``__init__`` uses ``connection_context()``. Every other
    method relies on peewee opening a connection on first use in the calling
    thread. That connection is only handed back to the pool when ``close()``
    is called from the same thread (see ``quit``). With ``max_connections=10``,
    many short-lived worker threads could exhaust the pool. Confirm how
    callers use threads.
    """

    # Rows per INSERT statement. Keeps the number of bound SQL variables per
    # statement small (SQLite limits variables per statement).
    _INSERT_BATCH_SIZE: int = 5

    def __init__(self) -> None:
        """Bind to the shared pooled database and create any missing tables."""
        self.db: PeeweeSqliteDatabase = db
        with self.db.connection_context():
            self.db.create_tables([AccountModelDataTable, RiskManagerModelTable])

    def save_account_data(self, accounts: List[AccountModelData]) -> bool:
        """Upsert account records.

        Each object's ``__dict__`` is used as the row, so its attribute names
        must match the ``AccountModelDataTable`` columns. Rows are written with
        ``INSERT OR REPLACE`` inside one transaction. A row that collides on the
        primary key or on the unique ``(username, account)`` index replaces the
        stored one.

        :param accounts: account objects to persist; an empty list writes nothing.
        :return: always ``True``.
        """
        rows: List[dict] = [account.__dict__ for account in accounts]
        self._replace_rows(AccountModelDataTable, rows)
        return True

    def save_riskManager_data(self, riskManagers: List[RiskManagerModelData]) -> bool:
        """Upsert risk-manager records.

        Same semantics as ``save_account_data``, applied to
        ``RiskManagerModelTable``.

        :param riskManagers: risk-manager objects to persist.
        :return: always ``True``.
        """
        rows: List[dict] = [risk_manager.__dict__ for risk_manager in riskManagers]
        self._replace_rows(RiskManagerModelTable, rows)
        return True

    @staticmethod
    def load_account_data() -> List[AccountModelData]:
        """Load every row of the ``account`` table.

        :return: one ``AccountModelData`` per row, in the order SQLite returns
            them (no ORDER BY is applied).
        """
        return [
            SqliteDatabase._to_account_data(row)
            for row in AccountModelDataTable.select()
        ]

    @staticmethod
    def load_riskManager_data() -> List[RiskManagerModelData]:
        """Load every row of the ``riskManager`` table.

        :return: one ``RiskManagerModelData`` per row, in the order SQLite
            returns them (no ORDER BY is applied).
        """
        return [
            SqliteDatabase._to_risk_manager_data(row)
            for row in RiskManagerModelTable.select()
        ]

    @staticmethod
    def load_riskManager_data_by_account(account: str) -> RiskManagerModelData:
        """Load the risk-manager settings for *account*.

        The table is unique only on ``(username, account)``. If several users
        share the account number, the first row SQLite returns is used.

        :param account: account number to look up.
        :return: the matching ``RiskManagerModelData``.
        :raises IndexError: if no row exists for *account*. IndexError is kept
            for compatibility with the previous ``[0]`` lookup.
        """
        row = (
            RiskManagerModelTable.select()
            .where(RiskManagerModelTable.account == account)
            .first()  # adds LIMIT 1; returns None when nothing matches
        )
        if row is None:
            raise IndexError(f"no risk-manager settings for account {account!r}")
        return SqliteDatabase._to_risk_manager_data(row)

    @staticmethod
    def delete_account_data(
            account: str,
    ) -> int:
        """Delete account rows whose ``account`` equals *account*.

        Matches on ``account`` only, so rows for every username sharing the
        account number are removed.

        :return: number of rows deleted.
        """
        d: ModelDelete = AccountModelDataTable.delete().where(
            AccountModelDataTable.account == account
        )
        count: int = d.execute()
        return count

    @staticmethod
    def delete_riskManager_data(
            account: str,
    ) -> int:
        """Delete risk-manager rows whose ``account`` equals *account*.

        Matches on ``account`` only, so rows for every username sharing the
        account number are removed.

        :return: number of rows deleted.
        """
        d: ModelDelete = RiskManagerModelTable.delete().where(
            RiskManagerModelTable.account == account
        )
        count: int = d.execute()
        return count

    def quit(self):
        """Close the calling thread's database connection.

        For the pooled database this hands the connection back to the pool.
        """
        self.db.close()

    def _replace_rows(self, model, rows: List[dict]) -> None:
        """INSERT OR REPLACE *rows* into *model* in batches within one transaction."""
        with self.db.atomic():
            for batch in chunked(rows, self._INSERT_BATCH_SIZE):
                model.insert_many(batch).on_conflict_replace().execute()

    @staticmethod
    def _to_account_data(row) -> AccountModelData:
        """Convert an ``AccountModelDataTable`` row into ``AccountModelData``."""
        return AccountModelData(
            id=row.id,
            username=row.username,
            account=row.account,
            password=row.password,
            broker=row.broker,
            trade_server=row.trade_server,
            quotation_server=row.quotation_server,
            product_name=row.product_name,
            authorization_code=row.authorization_code,
            init_balance=row.init_balance,
        )

    @staticmethod
    def _to_risk_manager_data(row) -> RiskManagerModelData:
        """Convert a ``RiskManagerModelTable`` row into ``RiskManagerModelData``."""
        return RiskManagerModelData(
            id=row.id,
            username=row.username,
            account=row.account,
            max_loss=row.max_loss,
            sleep=row.sleep,
        )
